import torch
import numpy as np
import click
from bgmol.datasets import AImplicitUnconstrained
import mdtraj as md
from bgflow.utils import (
    as_numpy,
)
from bgflow import XTBEnergy, XTBBridge
from tqdm import tqdm

# Module-level torch device: CUDA when a GPU is visible, otherwise CPU.
# NOTE(review): not referenced by the code paths in this script — confirm before removing.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class XTB_Target:
    """xTB (implicit water, 300 K) energy/force target for alanine dipeptide.

    Coordinates passed to ``energy``/``force`` live in a rescaled system: they
    are divided by ``self.scaling`` before reaching xTB. Forces are divided by
    ``self.scaling`` once more, so they are the negative energy gradient with
    respect to the rescaled coordinates (chain rule).

    Args:
        path: Directory holding the ``AD2_relaxed*.npy`` data files.
    """

    # Atom-name first character -> atomic number passed to the xTB bridge.
    # An atom whose name starts with any other character raises KeyError.
    _ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8}

    def __init__(self, path):
        self.path = path
        self.dataset = AImplicitUnconstrained(read=True)

        self.target = self.dataset.get_energy_model()

        # Only the topology is needed to derive atom types; wrapping the whole
        # dataset's coordinates in an md.Trajectory just to reach it is wasteful.
        topology = self.dataset.system.mdtraj_topology
        numbers = np.array(
            [self._ATOMIC_NUMBERS[atom.name[0]] for atom in topology.atoms]
        )
        # Number of atoms; flattened configurations have 3 * n_atoms entries.
        self.n_atoms = len(numbers)

        temperature = 300
        bridge_xtb = XTBBridge(
            numbers=numbers, temperature=temperature, solvent="water"
        )
        self.target_xtb = XTBEnergy(
            bridge_xtb,
            two_event_dims=False,
        )
        self.scaling = 10

    def load_data(self):
        """Load ``AD2_relaxed.npy`` from ``self.path`` as rescaled float32 coordinates.

        The stored values are divided by 10 (presumably Angstrom -> nm; TODO
        confirm) and then multiplied by ``self.scaling``.
        """
        return (
            torch.from_numpy(np.load(f"{self.path}/AD2_relaxed.npy")).float()
            / 10
            * self.scaling
        )

    def load_forces(self):
        """Load precomputed forces from ``self.path`` as a float32 tensor.

        Reads ``AD2_relaxed_force.npy``. If that file is missing, falls back to
        ``AD2_relaxed_forces.npy``, the name this script's ``main`` writes.

        Raises:
            FileNotFoundError: if neither file exists.
        """
        primary = f"{self.path}/AD2_relaxed_force.npy"
        fallback = f"{self.path}/AD2_relaxed_forces.npy"
        try:
            forces = np.load(primary)
        except FileNotFoundError:
            try:
                forces = np.load(fallback)
            except FileNotFoundError as err:
                raise FileNotFoundError(
                    f"Neither {primary} nor {fallback} exists"
                ) from err
        return torch.from_numpy(forces).float()

    def energy(self, sample):
        """Return the xTB energy of ``sample`` given in rescaled coordinates.

        ``sample`` is reshaped to ``(-1, 3 * n_atoms)`` (66 for alanine
        dipeptide) before being passed to xTB.
        """
        flat = sample.reshape((-1, 3 * self.n_atoms))
        return self.target_xtb.energy(flat / self.scaling)

    def force(self, sample):
        """Return xTB forces w.r.t. the rescaled coordinates of ``sample``.

        Unlike ``energy``, ``sample`` is passed through without reshaping.
        """
        # Divide inputs by `scaling` for xTB; divide the force by `scaling`
        # again because d/dx_rescaled = (1 / scaling) * d/dx_xtb.
        return self.target_xtb.force(sample / self.scaling) / self.scaling

    def directional_force_check(self, sample, eps=1e-3, v=None):
        """Estimate the force along direction ``v`` by central finite differences.

        Computes ``-(E(x + eps*v) - E(x - eps*v)) / (2*eps)``, which
        approximates ``-grad E(x) . v`` in rescaled coordinates. The result can
        be compared with the analytic ``force`` projected onto ``v``.

        Args:
            sample: Configuration(s) in rescaled coordinates.
            eps: Finite-difference step size.
            v: Direction; a standard-normal tensor like ``sample`` if None.

        Returns:
            Tuple ``(num_force, v)``.
        """
        v = torch.randn_like(sample) if v is None else v
        energy_plus = self.energy(sample + eps * v)
        energy_minus = self.energy(sample - eps * v)
        num_force = -(energy_plus - energy_minus) / (2 * eps)
        return num_force, v


@click.command()
@click.option("--data_path", default="/data")
def main(data_path):
    """Recompute xTB forces for every stored AD2 configuration.

    Loads configurations from ``data_path`` via ``XTB_Target.load_data``. It
    then evaluates ``XTB_Target.force`` on each entry along the first axis,
    prints the shape of the stacked result, and writes that result to
    ``AD2_relaxed_forces.npy`` in the current working directory.
    """
    target = XTB_Target(data_path)
    configurations = target.load_data()

    # One force evaluation per configuration, converted to NumPy immediately.
    per_sample = [as_numpy(target.force(conf)) for conf in tqdm(configurations)]
    stacked = np.array(per_sample)
    print(stacked.shape)

    np.save("AD2_relaxed_forces.npy", stacked)
    print("You still have to put this file in the data folder")


if __name__ == "__main__":
    # Script entry point; `main` is a click command, so its options
    # (e.g. --data_path) come from the command line.
    main()
